K-means

library("dslabs")
library("ggplot2")
library("knitr")
library("tidymodels")
library("tidyr")
library("dplyr")
theme_set(theme_minimal())
  1. The goal of clustering is to discover distinct groups within a dataset. In an ideal clustering, samples are very different between groups, but relatively similar within groups. At the end of a clustering routine, \(K\) clusters have been identified, and each sample is assigned to one of these \(K\) clusters. \(K\) must be chosen by the user.

  2. Clustering gives a compressed representation of the dataset. Therefore, clustering is useful for getting a quick overview of the high-level structure in a dataset.

  3. For example, clustering can be used in the following applications,

  1. \(K\)-means is a particular algorithm for finding clusters. First, it randomly initializes \(K\) cluster centroids. Then, it alternates the following two steps until convergence,
    • Assign points to their nearest cluster centroid.
    • Update the \(K\) centroids to be the averages of points within their cluster.
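To make the two alternating steps concrete, here is a minimal from-scratch sketch in R. It makes simplifying assumptions (numeric input, no cluster ever becomes empty) and is not how the built-in kmeans function works internally: its default algorithm is Hartigan-Wong. The function name simple_kmeans and the toy data are hypothetical, chosen only for illustration.

```r
# Hypothetical sketch of the two-step K-means loop (Lloyd's algorithm).
# Assumptions: x is a numeric matrix with one row per sample, and no cluster
# ever becomes empty (this sketch does not handle that case).
simple_kmeans <- function(x, K, max_iter = 100) {
  x <- as.matrix(x)
  # Random initialization: use K distinct samples as the starting centroids
  centroids <- x[sample(nrow(x), K), , drop = FALSE]
  assignment <- rep(0L, nrow(x))
  for (iter in seq_len(max_iter)) {
    # Step 1: squared distance from every sample (rows) to every centroid (columns)
    d <- sapply(seq_len(K), function(k) colSums((t(x) - centroids[k, ])^2))
    new_assignment <- max.col(-d, ties.method = "first")
    # Converged once no sample changes cluster
    if (all(new_assignment == assignment)) break
    assignment <- new_assignment
    # Step 2: move each centroid to the average of the samples assigned to it
    for (k in seq_len(K)) {
      centroids[k, ] <- colMeans(x[assignment == k, , drop = FALSE])
    }
  }
  list(cluster = assignment, centers = centroids)
}

# Usage on a toy dataset with two well-separated groups
set.seed(1)
toy <- rbind(c(0, 0), c(0, 1), c(10, 10), c(10, 11))
res <- simple_kmeans(toy, K = 2)
res$cluster
res$centers
```

On this toy data, the two left points and the two right points end up in separate clusters whichever pair of rows is drawn as the initial centroids. On real data, different initializations can converge to different solutions, which is why kmeans offers an nstart argument.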

Here is an animation from the tidymodels page on \(K\)-means.

  1. Note that, since we have to take an average for each coordinate, we require that our data be quantitative, not categorical.

  2. We illustrate this idea using the movielens dataset from the reading. This dataset has ratings (0.5 to 5) given by 671 users across 9066 movies. We can think of this as a matrix of movies vs. users, with ratings within the entries. For simplicity, we filter down to only the 50 most frequently rated movies. We will assume that if a user never rated a movie, they would have given that movie a zero. We’ve skipped a few steps used in the reading (subtracting movie / user averages and filtering to only active users), but the overall results are comparable.

data("movielens")
frequently_rated <- movielens %>%
  group_by(movieId) %>%
  summarize(n=n()) %>%
  top_n(50, n) %>%
  pull(movieId)
movie_mat <- movielens %>% 
  filter(movieId %in% frequently_rated) %>%
  select(title, userId, rating) %>%
  pivot_wider(id_cols = title, names_from = userId, values_from = rating, values_fill = 0)
movie_mat[1:10, 1:20]
## # A tibble: 10 x 20
##    title   `2`   `3`   `4`   `5`   `6`   `7`   `8`   `9`  `10`  `11`  `12`  `13`
##    <chr> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl>
##  1 Seve~     4   0       0     0     0     0   5       3     0     0     0   2.5
##  2 Usua~     4   0       0     0     0     0   5       0     5     5     0   0  
##  3 Brav~     4   4       0     0     0     5   4       0     0     0     0   4  
##  4 Apol~     5   0       0     4     0     0   0       0     0     0     0   0  
##  5 Pulp~     4   4.5     5     0     0     0   4       0     0     5     0   3.5
##  6 Forr~     3   5       5     4     0     3   4       0     0     0     0   5  
##  7 Lion~     3   0       5     4     0     3   0       0     0     0     0   0  
##  8 Mask~     3   0       4     4     0     3   0       0     0     0     0   0  
##  9 Speed     3   2.5     0     4     0     3   0       0     0     0     0   0  
## 10 Fugi~     3   0       0     0     0     0   4.5     0     0     0     0   0  
## # ... with 7 more variables: 14 <dbl>, 15 <dbl>, 16 <dbl>, 17 <dbl>, 18 <dbl>,
## #   19 <dbl>, 20 <dbl>
  1. Next, we run kmeans on this dataset. I’ve used the dplyr pipe notation to run kmeans on the data above with “title” removed. augment is a function from the broom package (loaded as part of tidymodels) that adds the cluster labels identified by kmeans to the rows in the original dataset. Similarly, tidy returns one row per cluster, containing that cluster's centroid.
kclust <- movie_mat %>%
  select(-title) %>%
  kmeans(centers = 10)
movie_mat <- augment(kclust, movie_mat) # creates column ".cluster" with cluster label
kclust <- tidy(kclust)
movie_mat %>%
  select(title, .cluster) %>%
  arrange(.cluster)
## # A tibble: 50 x 2
##    title                      .cluster
##    <chr>                      <fct>   
##  1 Forrest Gump               1       
##  2 Schindler's List           1       
##  3 Silence of the Lambs, The  1       
##  4 Braveheart                 2       
##  5 Apollo 13                  2       
##  6 Speed                      2       
##  7 Fugitive, The              2       
##  8 Jurassic Park              2       
##  9 Terminator 2: Judgment Day 2       
## 10 Dances with Wolves         2       
## # ... with 40 more rows
  1. There are two pieces of derived data generated by this routine, and both can be the subjects of visualization:
    • The cluster assignments
    • The cluster centroids
  2. In our movie example, the cluster centroids are imaginary pseudo-movies that are representative of their cluster. They are represented by the scores they would have received from each of the users in the dataset. This is visualized below. In a more realistic application, we would also want to display some information about each user; e.g., maybe some movies are more popular among certain age groups or in certain regions.
kclust_long <- kclust %>%
  pivot_longer(`2`:`671`, names_to = "userId", values_to = "rating")
ggplot(kclust_long) +
  geom_bar(
    aes(x = reorder(userId, rating), y = rating),
    stat = "identity"
  ) +
  facet_grid(cluster ~ .) +
  labs(x = "Users (sorted)", y = "Rating") +
  theme(
    axis.text.x = element_blank(),
    axis.text.y = element_text(size = 5),
    strip.text.y = element_text(angle = 0)
  )

We can visualize each cluster by seeing the average ratings each user gave to the movies in that cluster (this is the definition of the centroid). An alternative visualization strategy would be to show a heatmap – we’ll discuss this soon in the superheat lecture.
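As a sketch of the two derived outputs in base R, the snippet below fits kmeans to the four numeric columns of the built-in iris data. This is a stand-in for movie_mat, chosen only so the example runs on its own. It then checks that each centroid is the coordinate-wise average of its cluster's rows, the same definition of the centroid used in the plot above.

```r
# Sketch: the two derived outputs of a K-means fit (iris is stand-in data).
set.seed(1)
x <- as.matrix(iris[, 1:4])
fit <- kmeans(x, centers = 3, nstart = 10)

# Derived data 1: one cluster label per row of x
head(fit$cluster)

# Derived data 2: one centroid per cluster (rows) in the original feature space (columns)
fit$centers

# Each centroid is the coordinate-wise average of the rows assigned to its cluster
manual_centers <- t(sapply(1:3, function(k) colMeans(x[fit$cluster == k, , drop = FALSE])))
all.equal(unname(manual_centers), unname(fit$centers))
```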

  1. It’s often of interest to relate the cluster assignments to complementary data, to see whether the clustering reflects any previously known differences between the observations, which weren’t directly used in the clustering algorithm.

  2. Be cautious: Outliers, nonspherical shapes, and variations in density can throw off \(K\)-means.

The difficulty that variations in density pose to \(K\)-means, from Cluster Analysis using K-Means Explained.
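To illustrate both points above (relating clusters to complementary data, and the trouble with nonspherical shapes), the sketch below simulates two concentric rings. It then cross-tabulates the kmeans assignments against the known ring membership, which kmeans never sees. The ring radii, sample size, and noise level are assumptions made for illustration.

```r
# Sketch: K-means on two concentric rings (simulated, nonspherical clusters).
set.seed(1)
n <- 200
theta <- runif(2 * n, 0, 2 * pi)
radius <- rep(c(1, 3), each = n) + rnorm(2 * n, sd = 0.1)
rings <- cbind(x = radius * cos(theta), y = radius * sin(theta))
ring <- rep(c("inner", "outer"), each = n)  # complementary data: never given to kmeans

fit <- kmeans(rings, centers = 2, nstart = 20)

# Relate the cluster assignments to the known ring membership
tab <- table(cluster = fit$cluster, ring = ring)
tab
```

Because each point goes to its nearest centroid, the boundary between two \(K\)-means clusters is a straight line. It therefore tends to cut both rings in half rather than separate inner from outer, so each row of the table should contain points from both rings.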

  1. The goals of clustering are highly problem dependent, and different goals might call for alternative algorithms. For example, consider the ways clustering might be used to understand disease transmission. One problem might be to cluster the DNA sequences of the pathogenic agent, to recover its evolutionary history. This could be done using hierarchical clustering (next lecture). A second problem might be to determine whether patient outcomes might be driven by one of a few environmental factors, in which case a \(K\)-means clustering across the typical environmental factors would be reasonable. A third use would be to perform contact tracing, based on a network clustering algorithm. The point is that no clustering algorithm is uniformly better than any other in all situations, and the choice of which one to use should be guided by the problem requirements.

Hierarchical Clustering

library("dplyr")
library("ggplot2")
library("ggraph")
library("knitr")
library("readr")
library("robservable")
library("tibble") # for column_to_rownames()
library("tidygraph")
theme_set(theme_graph())
  1. In reality, data are rarely separated into a clear number of homogeneous clusters. More often, even once a cluster is formed, it’s possible to identify a few subclusters. For example, if you initially clustered movies into “drama” and “scifi”, you might be able to further refine the scifi cluster into “time travel” and “aliens.”

  2. \(K\)-means only allows clustering at a single level of magnification. To instead simultaneously cluster across scales, you can use an approach called hierarchical clustering. As a first observation, note that a tree can be used to implicitly store many clusterings at once. You can get a standard clustering by cutting the tree at some level.

We can recover clusters at different levels of granularity, by cutting a hierarchical clustering tree.
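Here is a minimal sketch of cutting one tree at two levels, using simulated 2-D points (an assumption for illustration) and base R's hclust and cutree. A coarser clustering is obtained by merging clusters from a finer one, so every fine cluster should fall entirely inside a single coarse cluster.

```r
# Sketch: one hierarchical clustering tree stores clusterings at every level.
set.seed(1)
x <- matrix(rnorm(30 * 2), ncol = 2)  # 30 simulated 2-D observations
tree <- hclust(dist(x))               # complete linkage by default

# Cut the same tree at two levels of granularity
coarse <- cutree(tree, k = 2)
fine <- cutree(tree, k = 5)

# Rows: fine clusters; columns: coarse clusters.
# Each row has exactly one nonzero entry, so the clusterings are nested.
nesting <- table(fine = fine, coarse = coarse)
nesting
```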

  1. These hierarchical clustering trees can be thought of as abstract versions of taxonomic trees. Instead of relating species, they relate observations in a dataset.
robservable("@mbostock/tree-of-life", height = 1150)
  1. Elaborating on this analogy, the leaves of a hierarchical clustering tree are the original observations. The more recently two nodes share a common ancestor, the more similar those observations are.

  2. The specific algorithm proceeds as follows,

    • Initialize: Associate each point with a cluster \(C_i := \{x_i\}\).
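    • Iterate until only one cluster remains: Look at all pairs of clusters. Merge the pair \(C_k, C_{k^{\prime}}\) which are the most similar. (How similarity between two clusters is measured is called the linkage; for example, complete linkage, the default in R's hclust, uses the largest distance between a member of one cluster and a member of the other.)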

At initialization, the hierarchical clustering routine has a cluster for each observation.

Next, the two closest observations are merged into one cluster. This is the first merge point on the tree.

We continue this at the next iteration, though this time we have to compute the pairwise distances between all clusters, not observations (technically, all the observations were their own cluster at the first step, so in both cases we compare the pairwise distances between clusters).

We can continue this process…

… and eventually we will construct the entire tree.
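The merge loop described above can be written out directly. The sketch below is a deliberately slow, hypothetical implementation: naive_complete_linkage is not a library function. It measures the distance between two clusters by complete linkage, the default method of hclust, and checks that it produces the same merge heights as hclust on a small simulated dataset.

```r
# Hypothetical naive agglomerative clustering with complete linkage.
naive_complete_linkage <- function(D) {
  D <- as.matrix(D)
  clusters <- as.list(seq_len(nrow(D)))  # initialize: each observation is its own cluster
  heights <- numeric(0)
  while (length(clusters) > 1) {
    best <- NULL
    best_d <- Inf
    # Look at all pairs of current clusters
    for (a in seq_len(length(clusters) - 1)) {
      for (b in (a + 1):length(clusters)) {
        # Complete linkage: largest distance between a member of a and a member of b
        d_ab <- max(D[clusters[[a]], clusters[[b]]])
        if (d_ab < best_d) {
          best_d <- d_ab
          best <- c(a, b)
        }
      }
    }
    # Merge the closest pair (b > a, so removing b does not shift a's index)
    clusters[[best[1]]] <- c(clusters[[best[1]]], clusters[[best[2]]])
    clusters[[best[2]]] <- NULL
    heights <- c(heights, best_d)  # record the height of this merge
  }
  heights
}

set.seed(1)
x <- matrix(rnorm(12 * 2), ncol = 2)  # 12 simulated 2-D observations
D <- dist(x)
naive_heights <- naive_complete_linkage(D)
hclust_heights <- hclust(D, method = "complete")$height
all.equal(sort(naive_heights), sort(hclust_heights))
```

The double loop recomputes every cluster-to-cluster distance at each step, which is fine for a dozen points but far too slow for real data. hclust instead updates distances incrementally after each merge.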

  1. In R, this can be accomplished by using the hclust function. First, we compute the distances between all pairs of observations (this provides the similarities used in the algorithm). Then, we apply hclust to the matrix of pairwise distances.

  2. We apply this to a movie ratings dataset. Movies are considered similar if they tend to receive similar ratings across all audience members. The result is visualized below.

movies_mat <- read_csv("https://uwmadison.box.com/shared/static/wj1ln9xtigaoubbxow86y2gqmqcsu2jk.csv")
## 
## -- Column specification --------------------------------------------------------
## cols(
##   .default = col_double(),
##   title = col_character()
## )
## i Use `spec()` for the full column specifications.
D <- movies_mat %>%
  column_to_rownames(var = "title") %>%
  dist()
hclust_result <- hclust(D)
plot(hclust_result, cex = 0.5)

  1. We can customize our tree visualization using the ggraph package. We can convert the hclust object into a ggraph, using the same as_tbl_graph function from the network and trees lectures.
hclust_graph <- as_tbl_graph(hclust_result, height = height)
hclust_graph <- hclust_graph %>%
  mutate(height = ifelse(height == 0, 27, height)) # shorten the final edge
hclust_graph
## # A tbl_graph: 99 nodes and 98 edges
## #
## # A rooted tree
## #
## # Node Data: 99 x 4 (active)
##   height leaf  label                       members
##    <dbl> <lgl> <chr>                         <int>
## 1   27   TRUE  "Schindler's List"                1
## 2   27   TRUE  "Forrest Gump"                    1
## 3   27   TRUE  "Shawshank Redemption, The"       1
## 4   27   TRUE  "Pulp Fiction"                    1
## 5   27   TRUE  "Silence of the Lambs, The"       1
## 6   58.7 FALSE ""                                2
## # ... with 93 more rows
## #
## # Edge Data: 98 x 2
##    from    to
##   <int> <int>
## 1     6     4
## 2     6     5
## 3     7     3
## # ... with 95 more rows
ggraph(hclust_graph, "dendrogram", height = height, circular = TRUE) +
  geom_edge_elbow() +
  geom_node_text(aes(label = label), size = 4) +
  coord_fixed()

  1. We can cut the tree to recover a standard clustering. This is where the grammar-of-graphics approach from ggraph becomes useful – we can encode the cluster membership of a movie using color, for example.
cluster_df <- cutree(hclust_result, k = 10) %>% # try changing K and regenerating the graph below
  tibble(label = names(.), cluster = as.factor(.))
cluster_df
## # A tibble: 50 x 3
##        . label                cluster
##    <int> <chr>                <fct>  
##  1     1 Seven (a.k.a. Se7en) 1      
##  2     1 Usual Suspects, The  1      
##  3     2 Braveheart           2      
##  4     2 Apollo 13            2      
##  5     3 Pulp Fiction         3      
##  6     4 Forrest Gump         4      
##  7     2 Lion King, The       2      
##  8     2 Mask, The            2      
##  9     2 Speed                2      
## 10     2 Fugitive, The        2      
## # ... with 40 more rows
# colors chosen using https://medialab.github.io/iwanthue/
cols <- c("#51b48c", "#cf3d6e", "#7ab743", "#7b62cb", "#c49644", "#c364b9", "#6a803a", "#688dcd", "#c95a38", "#c26b7e")
hclust_graph %>%
  left_join(cluster_df) %>%
  ggraph("dendrogram", height = height, circular = TRUE) +
  geom_edge_elbow() +
  geom_node_text(aes(label = label, col = cluster), size = 4) +
  coord_fixed() +
  scale_color_manual(values = cols) +
  theme(legend.position = "none")
## Joining, by = "label"

Heatmaps

library("dplyr")
library("ggplot2")
library("readr")
library("superheat")
library("tibble")
theme_set(theme_minimal())
  1. The direct outputs of a standard clustering algorithm are (a) cluster assignments for each sample and (b) the centroids associated with each cluster. A hierarchical clustering algorithm enriches this output with a tree, which provides (a) and (b) at multiple levels of resolution.

  2. These outputs can be used to improve visualizations. For example, they can be used to define small multiples, faceting across clusters. One especially common idea is to reorder the rows of a heatmap using the results of a clustering, and this is the subject of these notes.

  3. In a heatmap, each mark (usually a small tile) corresponds to an entry of a matrix. The \(y\)-coordinate of the mark encodes the row index (here, a movie), while the \(x\)-coordinate encodes the column index (here, a user). The color of each tile represents the value of that entry. For example, here are the first few rows of the movies data, along with the corresponding heatmap, made using the superheat package.

movies_mat <- read_csv("https://uwmadison.box.com/shared/static/wj1ln9xtigaoubbxow86y2gqmqcsu2jk.csv") %>%
  column_to_rownames(var = "title")
## 
## -- Column specification --------------------------------------------------------
## cols(
##   .default = col_double(),
##   title = col_character()
## )
## i Use `spec()` for the full column specifications.
cols <- c('#f6eff7','#bdc9e1','#67a9cf','#1c9099','#016c59')
superheat(movies_mat, left.label.text.size = 4, heat.pal = cols, heat.lim = c(0, 5))
## Warning in regularize.values(x, y, ties, missing(ties), na.rm = na.rm):
## collapsing to unique 'x' values

  1. Just like in adjacency matrix visualizations, the effectiveness of a heatmap can depend dramatically on the way in which rows and columns are ordered. To provide a more coherent view, we cluster both rows and columns, placing rows/columns belonging to the same cluster next to one another.
movies_clust <- movies_mat %>%
  kmeans(centers = 10)
users_clust <- movies_mat %>%
  t() %>%
  kmeans(centers = 10)
superheat(
  movies_mat, 
  left.label.text.size = 4, 
  order.rows = order(movies_clust$cluster),
  order.cols = order(users_clust$cluster),
  heat.pal = cols,
  heat.lim = c(0, 5)
)
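The reordering that superheat performs through order.rows and order.cols can be sketched in base R with image(). The matrix below is simulated, with three hidden row groups and two hidden column groups that are then shuffled. It is an assumption-laden toy rather than the movie data; the point is that sorting rows and columns by their kmeans labels brings the hidden blocks back together.

```r
# Sketch: reorder a matrix by K-means cluster labels before drawing it as a heatmap.
set.seed(1)
row_group <- rep(1:3, each = 10)
col_group <- rep(1:2, each = 15)
group_means <- matrix(c(0, 3, 1.5, 3, 0, 1.5), nrow = 3)  # mean for each (row group, column group)
m <- group_means[row_group, col_group] + matrix(rnorm(30 * 30, sd = 0.5), nrow = 30)
m <- m[sample(30), sample(30)]  # hide the structure by shuffling rows and columns

row_km <- kmeans(m, centers = 3, nstart = 10)     # cluster the rows
col_km <- kmeans(t(m), centers = 2, nstart = 10)  # cluster the columns

# Place rows (and columns) from the same cluster next to one another
m_ordered <- m[order(row_km$cluster), order(col_km$cluster)]

par(mfrow = c(1, 2))
image(t(m), axes = FALSE, main = "Shuffled")
image(t(m_ordered), axes = FALSE, main = "Ordered by cluster")
```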

  1. superheat also makes it easy to plot statistics adjacent to the main heatmap. These statistics can be plotted as points, lines, or bars. Points are useful when we want to highlight the raw value, lines are effective for showing change, and bars give a sense of the area below a set of observations. In this example, we use an added panel on the right-hand side (yr) to encode the total number of ratings given to each movie. The yr.obs.col argument allows us to change the color of each point (here, each bar) in the adjacent plot. In this example, we color each movie according to the cluster it was assigned to.
cluster_cols <- c('#8dd3c7','#ccebc5','#bebada','#fb8072','#80b1d3','#fdb462','#b3de69','#fccde5','#d9d9d9','#bc80bd')
superheat(
  movies_mat, 
  left.label.text.size = 4, 
  order.rows = order(movies_clust$cluster),
  order.cols = order(users_clust$cluster),
  heat.pal = cols,
  heat.lim = c(0, 5),
  yr = rowSums(movies_mat > 0),
  yr.axis.name = "Number of Ratings",
  yr.obs.col = cluster_cols[movies_clust$cluster],
  yr.plot.type = "bar"
)

  1. It also makes sense to order the rows / columns using hierarchical clustering. This approach is especially useful when the samples fall along a continuous gradient, rather than belonging to clearly delineated groups. The pretty.order.rows and pretty.order.cols arguments use hierarchical clustering to reorder the heatmap.
superheat(
  movies_mat, 
  left.label.text.size = 4, 
  pretty.order.cols = TRUE,  
  pretty.order.rows = TRUE,
  heat.pal = cols,
  heat.lim = c(0, 5)
)
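Here is a rough sketch of the idea behind pretty ordering, using base R. It clusters the rows (and, separately, the columns) with hclust and displays them in the tree's leaf order, which is the order component of an hclust object. This is an assumption about the general approach; superheat's own choice of distance and linkage may differ. The simulated matrix has a continuous gradient structure, the setting where the text suggests hierarchical ordering helps most.

```r
# Sketch: order rows and columns by the leaf order of hierarchical clustering trees.
set.seed(1)
gradient <- seq(0, 3, length.out = 25)
# Simulated gradient: entries depend smoothly on how far apart row and column positions are
m <- outer(gradient, gradient, function(a, b) -abs(a - b)) +
  matrix(rnorm(25 * 25, sd = 0.2), nrow = 25)
m <- m[sample(25), sample(25)]  # shuffle to hide the gradient

row_tree <- hclust(dist(m))     # tree over rows
col_tree <- hclust(dist(t(m)))  # tree over columns
m_ordered <- m[row_tree$order, col_tree$order]

par(mfrow = c(1, 2))
image(t(m), axes = FALSE, main = "Shuffled")
image(t(m_ordered), axes = FALSE, main = "Ordered by hclust leaves")
```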

  1. The hierarchical clustering trees estimated by pretty.order.rows and pretty.order.cols can also be visualized.
superheat(
  movies_mat, 
  left.label.text.size = 4, 
  pretty.order.cols = TRUE,  
  pretty.order.rows = TRUE, 
  row.dendrogram = TRUE,
  col.dendrogram = TRUE,
  heat.pal = cols,
  heat.lim = c(0, 5)
)
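If superheat is unavailable, base R's heatmap function (in the stats package) offers a similar display. By default, it clusters rows and columns with hclust, reorders the matrix by the resulting dendrograms, and draws both dendrograms in the margins. The simulated matrix, with one planted block of high values, is an assumption for illustration.

```r
# Sketch: base R heatmap with row and column dendrograms.
set.seed(1)
m <- matrix(rnorm(20 * 15), nrow = 20,
            dimnames = list(paste0("row", 1:20), paste0("col", 1:15)))
m[1:8, 1:5] <- m[1:8, 1:5] + 3  # plant one block of high values

# scale = "none" keeps the raw values (the default rescales each row)
hm <- heatmap(m, scale = "none", col = hcl.colors(25))
hm$rowInd  # permutation of the rows used in the display
hm$colInd  # permutation of the columns
```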

Silhouette Statistics

library("cluster")
library("stringr")
library("dplyr")
library("tidymodels")
library("readr")
library("ggplot2")
theme_set(theme_bw())
set.seed(123)
  1. Clustering algorithms usually require the number of clusters \(K\) as an argument. How should it be chosen?

  2. There are many possible criteria, but one common approach is to compute the silhouette statistic. It is a statistic that can be computed for each observation in a dataset, measuring how strongly it is tied to its assigned cluster. If a whole cluster has large silhouette statistics, then that cluster is well-defined and clearly isolated from other clusters.

  3. The plots below illustrate the computation of silhouette statistics for a clustering of the penguins dataset that used \(K = 3\). To set up, we first need to cluster the penguins dataset. The idea is the same as in the \(K\)-means notes, but we encapsulate the code in a function, so that we can easily extract data for different values of \(K\).

penguins <- read_csv("https://uwmadison.box.com/shared/static/ijh7iipc9ect1jf0z8qa2n3j7dgem1gh.csv") %>%
  na.omit() %>%
  mutate(id = row_number())
## 
## -- Column specification --------------------------------------------------------
## cols(
##   species = col_character(),
##   island = col_character(),
##   bill_length_mm = col_double(),
##   bill_depth_mm = col_double(),
##   flipper_length_mm = col_double(),
##   body_mass_g = col_double(),
##   sex = col_character(),
##   year = col_double()
## )
cluster_penguins <- function(penguins, K) {
  x <- penguins %>%
    select(matches("length|depth|mass")) %>%
    scale()
    
  kmeans(x, centers = K) %>%
    augment(penguins) %>% # creates column ".cluster" with cluster label
    mutate(silhouette = silhouette(as.integer(.cluster), dist(x))[, "sil_width"])
}
  1. Denote the silhouette statistic of observation \(i\) by \(s_{i}\). We will compute \(s_i\) for the observation highlighted in black below.¹
cur_id <- 2
penguins3 <- cluster_penguins(penguins, K = 3)
obs_i <- penguins3 %>%
  filter(id == cur_id)
ggplot(penguins3, aes(x = bill_length_mm, y = bill_depth_mm, col = .cluster)) +
  geom_point(data = obs_i, size = 5, col = "black") + 
  geom_point() +
  scale_color_brewer(palette = "Set2") +
  scale_size(range = c(4, 1))

The observation on which we will compute the silhouette statistic.

  1. The first step in calculating the silhouette statistic is to measure the pairwise distances between observation \(i\) and all other observations in the same cluster. These distances are the lengths of the short lines below. Call the average of these lengths \(a_{i}\).
ggplot(penguins3, aes(x = bill_length_mm, y = bill_depth_mm, col = .cluster)) +
  geom_segment(
    data = penguins3 %>% filter(.cluster == obs_i$.cluster), 
    aes(xend = obs_i$bill_length_mm, yend = obs_i$bill_depth_mm),
    size = 0.6, alpha = 0.3
  ) +
  geom_point(data = obs_i, size = 5, col = "black") + 
  geom_point() +
  scale_color_brewer(palette = "Set2") +
  scale_size(range = c(4, 1)) +
  labs(title = expression(paste("Distances used for ", a[i])))

The average distance between the target observation and all others in the same cluster.

  1. Next, we compute pairwise distances to all observations in clusters 2 and 3. The averages of these pairwise distances are called \(b_{i2}\) and \(b_{i3}\). Choose the smaller of \(b_{i2}\) and \(b_{i3}\), and call it \(b_{i}\). In a sense, this is the “next best” cluster in which to put observation \(i\). For a general \(K\), you would compute \(b_{ik}\) for all \(k\) other than observation \(i\)’s own cluster and take the minimum across all of them. In this case, the orange segments are on average shorter than the blue segments, so \(b_i\) is defined as the average length of the orange segments.
ggplot(penguins3, aes(x = bill_length_mm, y = bill_depth_mm, col = .cluster)) +
  geom_segment(
    data = penguins3 %>% filter(.cluster != obs_i$.cluster), 
    aes(xend = obs_i$bill_length_mm, yend = obs_i$bill_depth_mm, col = .cluster),
    size = 0.5, alpha = 0.3
  ) +
  geom_point(data = obs_i, size = 5, col = "black") + 
  geom_point() +
  scale_color_brewer(palette = "Set2") +
  scale_size(range = c(4, 1)) +
  labs(title = expression(paste("Distances used for ", b[i][2], " and ", b[i][3])))

The average distance between the target observation and all others in different clusters.

  1. The silhouette statistic for observation \(i\) is derived from the relative lengths of the orange vs. green segments. Formally, the silhouette statistic for observation \(i\) is \(s_{i} := \frac{b_{i} - a_{i}}{\max\left(a_{i}, b_{i}\right)}\). This number is close to 1 if the orange segments are much longer than the green segments, close to 0 if the segments are about the same size, and close to -1 if the orange segments are much shorter than the green segments.²

  2. The median of these \(s_{i}\) for all observations within cluster \(k\) is a measure of how well-defined cluster \(k\) is overall. The higher this number, the more well-defined the cluster.

  3. Denote the median of the silhouette statistics within cluster \(k\) by \(SS_{k}\). How good a choice of \(K\) is can be measured by the median of these medians: \(\text{Quality}(K) := \text{median}_{k = 1, \dots, K} SS_{k}\).

  4. In particular, this can be used to define (a) a good cut point in a hierarchical clustering or (b) a point at which a cluster should no longer be split into subgroups.

  5. In R, we can use the silhouette function from the cluster package to compute the silhouette statistic. The syntax is silhouette(cluster_labels, pairwise_distances), where cluster_labels is a vector of (integer) cluster IDs for each observation and pairwise_distances gives the lengths of the segments between all pairs of observations (e.g., the output of dist()). An example of this function’s usage is given in cluster_penguins, the function defined at the start of the illustration.

  6. This is what the silhouette statistic looks like in the penguins dataset when we choose 3 clusters. The larger points have lower silhouette statistics. The points between clusters 2 and 3 have low silhouette statistics because those two clusters blend into one another.

ggplot(penguins3) +
  geom_point(aes(x = bill_length_mm, y = bill_depth_mm, col = .cluster, size = silhouette)) +
  scale_color_brewer(palette = "Set2") +
  scale_size(range = c(4, 1))

The silhouette statistics on the Palmers Penguins dataset, when using \(K\)-means with \(K = 3\).
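To check the formula for \(s_i\) against the silhouette function, the sketch below computes \(a_i\), \(b_i\), and \(s_i\) by hand for every observation and compares the results with cluster::silhouette. It uses the scaled iris measurements as stand-in data, and manual_silhouette is a hypothetical helper written only for this illustration.

```r
library("cluster")
set.seed(1)
x <- scale(iris[, 1:4])  # stand-in data
cl <- kmeans(x, centers = 3, nstart = 10)$cluster
D <- as.matrix(dist(x))

# Hypothetical helper: silhouette statistic of observation i, following the formula above
manual_silhouette <- function(i, cl, D) {
  same <- cl == cl[i]
  same[i] <- FALSE                 # exclude observation i itself
  a_i <- mean(D[i, same])          # average distance within i's own cluster
  other_clusters <- setdiff(unique(cl), cl[i])
  # b_i: average distance to the closest ("next best") other cluster
  b_i <- min(sapply(other_clusters, function(k) mean(D[i, cl == k])))
  (b_i - a_i) / max(a_i, b_i)
}

s_manual <- sapply(seq_len(nrow(x)), manual_silhouette, cl = cl, D = D)
s_package <- silhouette(cl, dist(x))[, "sil_width"]
all.equal(s_manual, unname(s_package))
```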

  1. We can also visualize the histogram of silhouette statistics within each cluster. Since the silhouette statistics for cluster 2 are generally lower than those for the other two clusters (in particular, its median is lower), we can conclude that it is less well-defined.
ggplot(penguins3) +
  geom_histogram(aes(x = silhouette), binwidth = 0.05) +
  facet_grid(~ .cluster)

The per-cluster histograms of silhouette statistics summarize how well-defined each cluster is.

  1. If we choose even more clusters, then more points lie along the boundaries of poorly defined clusters. Their silhouette statistics end up lower, so they appear as larger points. From the histogram, we can also see a deterioration in the median silhouette scores across all clusters.
penguins4 <- cluster_penguins(penguins, K = 4)
ggplot(penguins4) +
  geom_point(aes(x = bill_length_mm, y = bill_depth_mm, col = .cluster, size = silhouette)) +
  scale_color_brewer(palette = "Set2") +
  scale_size(range = c(4, 1))

We can repeat the same exercise, but with \(K = 4\) clusters instead.

ggplot(penguins4) +
  geom_histogram(aes(x = silhouette), binwidth = 0.05) +
  facet_grid(~ .cluster)
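The \(\text{Quality}(K)\) criterion defined earlier can be computed for a range of \(K\) in a few lines. In this sketch, the scaled iris measurements stand in for the penguins data, and the quality helper is hypothetical. Larger values suggest better-separated clusters, but the result is one heuristic among many.

```r
library("cluster")
set.seed(1)
x <- scale(iris[, 1:4])  # stand-in data
D <- dist(x)

# Hypothetical helper: Quality(K) = median over clusters of the per-cluster median silhouette
quality <- function(K) {
  cl <- kmeans(x, centers = K, nstart = 10)$cluster
  s <- silhouette(cl, D)[, "sil_width"]
  median(tapply(s, cl, median))  # median of the SS_k values
}

Ks <- 2:6
quality_by_K <- sapply(Ks, quality)
data.frame(K = Ks, quality = quality_by_K)
```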

Cluster Stability

library("MASS")
library("Matrix")
library("dplyr")
library("ggplot2")
library("pdist")
library("superheat")
library("tidyr")
library(knitr)
theme_set(theme_minimal())
set.seed(1234)
  1. One of the fundamental principles in statistics is that, no matter how the experiment / study was conducted, if we ran it again, we would get different results. More formally, sampling variability creates uncertainty in our inferences.

  2. How should we think about sampling variability in the context of clustering? This is a tricky problem, because you can permute the labels of the clusters without changing the meaning of the clustering. However, it is possible to measure and visualize the stability of a point’s cluster assignment.

  3. To make this less abstract, consider an example. A study has found a collection of genes that are differentially expressed between patients with two different subtypes of a disease. There is an interest in clustering genes that have similar expression profiles across all patients — these genes probably belong to similar biological processes.

  4. Once you run the clustering, how sure can you be that, if the study were run again, you would recover a similar clustering? Are there some genes that you are sure belong to a particular cluster? Are there some that lie between two clusters?

  5. To illustrate, consider the simulated dataset below. Imagine that the rows are patients, the columns are genes, and the colors are the expression levels of genes within patients. There are 5 clusters of genes here (columns 1 - 20 are cluster 1, 21 - 40 are cluster 2, …). The first two clusters are only weakly visible, while the last three stand out strongly.

n_per <- 20
p <- n_per * 5
Sigma1 <- diag(2) %x% matrix(rep(0.3, n_per ** 2), nrow = n_per)
Sigma2 <- diag(3) %x% matrix(rep(0.6, n_per ** 2), nrow = n_per)
Sigma <- bdiag(Sigma1, Sigma2)
diag(Sigma) <- 1
mu <- rep(0, 100)
x <- mvrnorm(25, mu, Sigma)
cols <- c('#f6eff7','#bdc9e1','#67a9cf','#1c9099','#016c59')
superheat(
  x, 
  pretty.order.rows = TRUE, 
  bottom.label = "none", 
  heat.pal = cols,
  left.label.text.size = 3,
  legend = FALSE
)

A simulated clustering of genes (columns) across rows (patients).
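The label-permutation issue mentioned above is easy to see in a small simulation; the three well-separated blobs are an assumption for illustration. Two kmeans runs with different random seeds should find the same partition, but nothing guarantees they use the same integer labels. We therefore compare the runs with a cross-tabulation rather than by checking whether the labels are equal.

```r
# Sketch: the same partition can come back under different label names.
set.seed(1)
centers <- rbind(c(0, 0), c(10, 0), c(0, 10))
x <- centers[rep(1:3, each = 20), ] + matrix(rnorm(60 * 2, sd = 0.5), ncol = 2)

set.seed(10)
run1 <- kmeans(x, centers = 3, nstart = 50)$cluster
set.seed(20)
run2 <- kmeans(x, centers = 3, nstart = 50)$cluster

# Each label in run1 maps to exactly one label in run2 (and vice versa),
# so the two clusterings agree even if the integer labels differ.
agreement <- table(run1 = run1, run2 = run2)
agreement
```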

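  1. The main idea for computing cluster stability is to bootstrap (i.e., randomly resample) the patients and see whether the cluster assignments for each gene change. More precisely, we first run \(K\)-means on the genes using all patients to obtain \(K\) centroids. Then, in each of \(B\) bootstrap iterations, we resample the patients with replacement and reassign every gene to its nearest centroid, measuring distances only over the resampled patients.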
  1. The picture below describes the bootstrapping process for a gene. The two rows correspond to the original and bootstrapped representations of a specific gene, respectively. Each bar gives the expression level of the gene for one individual. Due to the random sampling in the bootstrapped dataset, some individuals become overrepresented and some are removed. If we resample the coordinates of the centroids in the same way, we get a new distance between each gene and each centroid. Since the set of included patients changes, these distances change, so the genes might be assigned to different clusters.

K <- 5
B <- 1000
cluster_profiles <- kmeans(t(x), centers = K)$centers
cluster_probs <- matrix(nrow = ncol(x), ncol = B)
for (b in seq_len(B)) {
  b_ix <- sample(nrow(x), replace = TRUE)
  dists <- as.matrix(pdist(t(x[b_ix, ]), cluster_profiles[, b_ix]))
  cluster_probs[, b] <- apply(dists, 1, which.min)
}
colnames(cluster_probs) <- paste0("b", seq_len(B)) # name the bootstrap columns so as_tibble() does not warn
cluster_probs <- as_tibble(cluster_probs) %>%
  mutate(gene = row_number()) %>%
  pivot_longer(-gene, names_to = "b", values_to = "cluster")
  1. The table below shows the result of this procedure. Gene 1 was assigned to cluster 2 in about 96% of bootstrap iterations, so we can rely on that assignment. On the other hand, gene 2 is assigned to cluster 2 only about 78% of the time, and to cluster 5 otherwise.
cluster_probs <- cluster_probs %>%
  mutate(cluster = as.factor(cluster)) %>%
  group_by(gene, cluster) %>%
  summarise(prob = n() / B)
## `summarise()` has grouped output by 'gene'. You can override using the `.groups` argument.
cluster_probs
## # A tibble: 267 x 3
## # Groups:   gene [100]
##     gene cluster  prob
##    <int> <fct>   <dbl>
##  1     1 2       0.956
##  2     1 3       0.041
##  3     1 5       0.003
##  4     2 2       0.778
##  5     2 5       0.222
##  6     3 1       0.001
##  7     3 2       0.978
##  8     3 5       0.021
##  9     4 1       0.001
## 10     4 2       0.689
## # ... with 257 more rows
  1. These fractions for all genes are summarized by the plot below. Each row is a gene. The length of each colored segment gives the proportion of bootstrap iterations in which that gene was assigned to that cluster. The genes in rows 41 - 100 are all clearly distinguished, which is in line with what we saw visually in the heatmap above. The first two clusters are somewhat recovered, but since their genes were often assigned to alternative clusters, we can conclude that they were harder to demarcate than the others.
ggplot(cluster_probs) +
  geom_bar(aes(y = as.factor(gene), x = prob, col = cluster, fill = cluster), stat = "identity") +
  scale_fill_brewer(palette = "Set2") +
  scale_color_brewer(palette = "Set2") +
  scale_x_continuous(expand = c(0, 0)) +
  labs(y = "Gene", x = "Proportion") +
  theme(
    axis.ticks.y = element_blank(),
    axis.text.y = element_text(size = 7),
    legend.position = "bottom"
  )


  1. You can change cur_id to try different observations.

  2. This last case likely indicates a misclustering.